Stable Diffusion code

2025-09-15

16:36

Two key premises to keep in mind before reading the code:

  1. Input object: the code operates on the latent produced by VAE downsampling, not the raw image. In Stable Diffusion, the original image (e.g., 256×256) is compressed by the VAE into a 4-channel latent at 1/8 resolution (e.g., 32×32), so every feature-map resolution in the code is expressed relative to (H/8, W/8).
  2. Core objective: the diffusion model is trained to predict noise. Forward diffusion adds random noise to the clean latent; given the noisy latent (xt), the timestep (t), and the condition (context), the model must predict the actual noise that was added. Reverse denoising then uses the predicted noise to recover the clean latent (see the sketch below).
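
The sketch below illustrates premise 2 with a DDPM-style training step. It is only an illustration: x0, context, alphas_cumprod, and get_time_embedding are assumed names (a clean VAE latent, a text embedding, a precomputed noise schedule, and a sinusoidal timestep encoder), and model refers to the Diffusion class defined at the bottom of this file.

import torch
from torch.nn import functional as F

# Illustrative noise-prediction training step (a sketch, not this file's code).
t = torch.randint(0, 1000, (1,))                        # random timestep
a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)             # cumulative schedule term for t (assumed precomputed)
noise = torch.randn_like(x0)                            # the noise the model must predict
xt = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise     # forward diffusion: noisy latent
pred = model(xt, context, get_time_embedding(int(t)))   # predicted noise, (B, 4, H/8, W/8)
loss = F.mse_loss(pred, noise)                          # train to match the true noise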

 

Summary: the key point is the cross-attention part of UNET_AttentionBlock, which treats the image as a sequence. In NLP the sequence length is the number of tokens in a passage; here the sequence length is the image's H*W, and the number of channels plays the role of the feature dimension, i.e., the length of a token embedding.
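
A quick shape sketch of that sequence view (shapes are illustrative, not tied to a specific checkpoint):

import torch

x = torch.randn(1, 320, 32, 32)                     # (B, C, H, W): a latent feature map
seq = x.view(1, 320, 32 * 32).transpose(-1, -2)     # (B, H*W, C) = (1, 1024, 320)
# sequence length = H*W = 1024 "pixel tokens"; feature dim = C = 320 (the "token embedding" length)
back = seq.transpose(-1, -2).view(1, 320, 32, 32)   # and back to (B, C, H, W)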

 

 

import torch

from torch import nn

from torch.nn import functional as F

from attention import SelfAttention, CrossAttention

 

class TimeEmbedding(nn.Module):

    def __init__(self, n_embd):

        super().__init__()

        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)

        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

 

    def forward(self, x): # x: (1, 320), the initial embedding of the timestep t (usually a sinusoidal encoding)

        # x: (1, 320)

 

        # (1, 320) -> (1, 1280)

        x = self.linear_1(x)

       

        # (1, 1280) -> (1, 1280)

        x = F.silu(x)

       

        # (1, 1280) -> (1, 1280)

        x = self.linear_2(x)

 

        return x # output: (1, 1280)
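
The (1, 320) input to TimeEmbedding is the initial timestep encoding. As a hedged sketch (the real helper lives outside this file), a standard sinusoidal version could look like this:

def get_time_embedding(timestep):
    # Hypothetical helper: sinusoidal encoding of a scalar timestep into (1, 320).
    freqs = torch.pow(10000, -torch.arange(0, 160, dtype=torch.float32) / 160)  # (160,) frequencies
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]    # (1, 160)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)                      # (1, 320)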

 

The residual block is the UNET's basic feature-transformation unit. It: 1) adjusts the number of feature channels; 2) injects the time embedding into the spatial features; 3) uses a residual connection to mitigate vanishing gradients.

class UNET_ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels, n_time=1280):  # n_time = output dim of TimeEmbedding

        super().__init__()

        # First branch: feature-map processing

        self.groupnorm_feature = nn.GroupNorm(32, in_channels)  # group normalization (32 groups, works well with small batches)

        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)  # 3×3 conv (resolution unchanged)

       

        # Second branch: time-embedding processing

        self.linear_time = nn.Linear(n_time, out_channels)  # project the time embedding to match the feature channel count

       

        # Third branch: post-merge processing

        self.groupnorm_merged = nn.GroupNorm(32, out_channels)

        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

       

        # Residual connection: if the channel counts differ, adjust with a 1×1 convolution

        self.residual_layer = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1)

 

    def forward(self, feature, time):  # feature: (B, in_ch, H, W); time: (1, 1280)

        residue = feature  # save the residual

       

        # 1. Process the feature map

        feature = self.groupnorm_feature(feature)  # normalize

        feature = F.silu(feature)  # activation

        feature = self.conv_feature(feature)  # channels: in → out (resolution unchanged)

       

        # 2. Process the time embedding and add it to the feature map via broadcasting

        time = F.silu(time)  # activate the time embedding

        time = self.linear_time(time)  # time embedding: 1280 → out_ch, giving (1, out_ch)

        merged = feature + time.unsqueeze(-1).unsqueeze(-1)  # Time injection: unsqueeze expands the (1, out_ch) time vector to (1, out_ch, 1, 1), which broadcasts against a feature map of any resolution (B, out_ch, H, W), fusing the time information into the spatial features

       

        # 3. Further processing after the merge

        merged = self.groupnorm_merged(merged)

        merged = F.silu(merged)

        merged = self.conv_merged(merged)

       

        # 4. Residual connection: input residual + processed features

        return merged + self.residual_layer(residue)
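
A small usage sketch of the residual block, mainly to see the broadcasting and the channel change (shapes illustrative):

block = UNET_ResidualBlock(320, 640)      # change channels 320 -> 640
feature = torch.randn(1, 320, 32, 32)     # (B, in_ch, H, W)
time = torch.randn(1, 1280)               # stand-in for the TimeEmbedding output
out = block(feature, time)
print(out.shape)                          # torch.Size([1, 640, 32, 32]) -- resolution unchanged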

 

 

The attention block is how the diffusion model captures spatial dependencies and conditional dependencies; it contains self-attention and cross-attention:

Self-attention: captures spatial relationships within the latent (e.g., how "the cat's ears" relate to "the cat's face");

Cross-attention: aligns the latent features with the conditioning information (e.g., the text embedding of "red rose" guides the model to generate red petals).

 

class UNET_AttentionBlock(nn.Module):

    def __init__(self, n_head: int, n_embd: int, d_context=768):  # d_context = text embedding dim (e.g., 768 for CLIP)

        super().__init__()

        self.channels = n_head * n_embd  # total attention dim (multi-head attention: number of heads × per-head dim)

       

        # Input processing

        self.groupnorm = nn.GroupNorm(32, self.channels)

        self.conv_input = nn.Conv2d(self.channels, self.channels, kernel_size=1)  # 1×1 conv to adjust the features

       

        # Self-attention branch

        self.layernorm_1 = nn.LayerNorm(self.channels)

        self.attention_1 = SelfAttention(n_head, self.channels)  # self-attention (no external condition)

       

        # Cross-attention branch

        self.layernorm_2 = nn.LayerNorm(self.channels)

        self.attention_2 = CrossAttention(n_head, self.channels, d_context)  # cross-attention (requires context)

       

        # FFN branch (GeGLU activation)

        self.layernorm_3 = nn.LayerNorm(self.channels)

        self.linear_geglu_1 = nn.Linear(self.channels, 4 * self.channels * 2)  # output is split into two parts (x and gate)

        self.linear_geglu_2 = nn.Linear(4 * self.channels, self.channels)

       

        # Output processing

        self.conv_output = nn.Conv2d(self.channels, self.channels, kernel_size=1)

 

    def forward(self, x, context):  # x: (B, C, H, W); context: (B, SeqLen, 768) (e.g., the text embedding)

        residue_long = x  # save the initial residual (spans the entire attention block)

       

        # 1. Input normalization and convolution

        x = self.groupnorm(x)

        x = self.conv_input(x)

       

        # 2. Feature map → sequence: (B, C, H, W) → (B, H×W, C) (attention expects sequence format)

        n, c, h, w = x.shape

        x = x.view(n, c, h * w).transpose(-1, -2)  # after the transpose: (B, H×W, C); H×W = sequence length, C = feature dim

       

        # 3. Self-attention (with residual)

        residue_short = x  # short residual (spans only the self-attention)

        x = self.layernorm_1(x)

        x = self.attention_1(x)  # self-attention: (B, H×W, C) → (B, H×W, C)

        x += residue_short

       

        # 4. Cross-attention (with residual)

        residue_short = x

        x = self.layernorm_2(x)

        x = self.attention_2(x, context)  # cross-attention: context guides x (e.g., text → image features)

        x += residue_short

       

        # 5. FFN (GeGLU activation, with residual)

        residue_short = x

        x = self.layernorm_3(x)

        # GeGLU: split into x and gate, then x * gelu(gate) (a more effective activation here than a plain ReLU)

        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)  # split into two parts: (B, H×W, 4C) and (B, H×W, 4C)

        x = x * F.gelu(gate)  # core GeGLU operation

        x = self.linear_geglu_2(x)  # 4C → C

        x += residue_short

       

        # 6. Sequence → feature map: (B, H×W, C) → (B, C, H, W)

        x = x.transpose(-1, -2).view(n, c, h, w)

       

        # 7. Final residual connection (initial input + attention-block output)

        return self.conv_output(x) + residue_long

Why the shape conversion: attention (as in a Transformer) expects inputs of shape (batch, sequence length, feature dim), whereas UNET features are (batch, channels, height, width); height × width is therefore flattened into the sequence length and the channel count serves as the feature dim.

Where the CrossAttention context comes from: it is usually the text embedding produced by a CLIP model (in Stable Diffusion, the text "a cat" is encoded by CLIP as (1, 77, 768), where 77 is the maximum text length and 768 is the feature dim).
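
SelfAttention and CrossAttention are imported from attention.py and not shown in this file. Purely as an assumption based on the CrossAttention(n_head, channels, d_context) call above (the real attention.py may differ, e.g. in bias handling), a minimal cross-attention could look like this: Q comes from the image sequence, K/V come from the text context.

class CrossAttentionSketch(nn.Module):
    # Hypothetical minimal cross-attention, not the actual implementation in attention.py.
    def __init__(self, n_heads, d_embed, d_context):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed)
        self.k_proj = nn.Linear(d_context, d_embed)
        self.v_proj = nn.Linear(d_context, d_embed)
        self.out_proj = nn.Linear(d_embed, d_embed)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, context):  # x: (B, H*W, C); context: (B, 77, 768)
        b, s, _ = x.shape
        q = self.q_proj(x).view(b, s, self.n_heads, self.d_head).transpose(1, 2)         # (B, heads, H*W, d_head)
        k = self.k_proj(context).view(b, -1, self.n_heads, self.d_head).transpose(1, 2)  # (B, heads, 77, d_head)
        v = self.v_proj(context).view(b, -1, self.n_heads, self.d_head).transpose(1, 2)  # (B, heads, 77, d_head)
        attn = F.softmax(q @ k.transpose(-1, -2) / self.d_head ** 0.5, dim=-1)           # (B, heads, H*W, 77)
        out = (attn @ v).transpose(1, 2).reshape(b, s, -1)                               # (B, H*W, C)
        return self.out_proj(out)  # same shape as x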

 

 

 

 

The decoder needs upsampling to restore the latent resolution (from H/64 back to H/8). This module uses nearest-neighbor interpolation for cheap upsampling, followed by a 3×3 convolution to refine the features and reduce interpolation blur. The encoder has no corresponding Downsample module, because its resolution reduction is done directly with stride-2 Conv2d layers.

 

class Upsample(nn.Module):

    def __init__(self, channels):

        super().__init__()

        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

 

    def forward(self, x):  # x: (B, C, H, W)

        # nearest-neighbor interpolation: H and W are doubled (e.g., H/32 → H/16)

        x = F.interpolate(x, scale_factor=2, mode='nearest')

        return self.conv(x)  # 3×3 conv: keeps the channel count and refines the interpolated features
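
A quick check of the doubling behavior (illustrative shapes):

up = Upsample(640)
x = torch.randn(1, 640, 16, 16)      # e.g. the H/16 features of a 256×256 image
print(up(x).shape)                   # torch.Size([1, 640, 32, 32]) -- H and W doubled, channels unchanged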

 

 

 

class SwitchSequential(nn.Sequential):

    def forward(self, x, context, time):

        for layer in self:

            # attention blocks additionally receive the conditioning context
            if isinstance(layer, UNET_AttentionBlock):

                x = layer(x, context)

            # residual blocks additionally receive the time embedding
            elif isinstance(layer, UNET_ResidualBlock):

                x = layer(x, time)

            else:

                # plain layers (e.g., Conv2d, Upsample) take only x
                x = layer(x)

        return x

 

 

UNET is the core network of the diffusion model. It uses a symmetric encoder-bottleneck-decoder structure: the encoder downsamples to extract global features, the bottleneck processes the most abstract features, and the decoder upsamples to restore detail, with skip connections feeding details back in, all in service of noise prediction.

Structure (as reflected in the code)

UNET takes the latent (B, 4, H/8, W/8) as input and outputs (B, 320, H/8, W/8) (later converted into 4-channel noise by UNET_OutputLayer).

 

class UNET(nn.Module):

    def __init__(self):

        super().__init__()

        # 1. Encoders: downsample + increase channels to extract global features.
        # Each encoder stage processes features with ResidualBlock + AttentionBlock, then a stride=2 convolution halves the resolution; every stage's output is collected into skip_connections for the decoder.

        self.encoders = nn.ModuleList([

            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)

            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),

           

            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)

            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),

           

            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)

            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),

           

            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)

            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),

           

            # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)

            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),

           

            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)

            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),

           

            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)

            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),

           

            # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)

            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),

           

            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)

            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),

           

            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)

            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),

           

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)

            SwitchSequential(UNET_ResidualBlock(1280, 1280)),

           

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)

            SwitchSequential(UNET_ResidualBlock(1280, 1280)),

        ])

        # 2. Bottleneck: lowest resolution, processes the most abstract features.
        # The bottleneck is the end of the encoder and the start of the decoder: lowest resolution (H/64), highest channel count (1280). It uses only ResidualBlock + AttentionBlock + ResidualBlock to process the most abstract global features (attention is cheapest here, which suits capturing global dependencies).

        self.bottleneck = SwitchSequential(

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)

            UNET_ResidualBlock(1280, 1280),

           

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)

            UNET_AttentionBlock(8, 160),

           

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)

            UNET_ResidualBlock(1280, 1280),

        )

        # 3. Decoders: upsample + reduce channels to recover detail.
        # The decoder mirrors the encoder. Each stage first concatenates the current features with the matching encoder skip connection (torch.cat, so the channel counts add up), restoring detail lost during downsampling, then processes them with ResidualBlock + AttentionBlock, and finally upsamples with Upsample (resolution doubled).

        self.decoders = nn.ModuleList([

            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)

            SwitchSequential(UNET_ResidualBlock(2560, 1280)),

           

            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)

            SwitchSequential(UNET_ResidualBlock(2560, 1280)),

           

            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)

            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),

           

            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)

            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),

           

            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)

            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),

           

            # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)

            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),

           

            # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)

            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),

           

            # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)

            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),

           

            # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)

            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),

           

            # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)

            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),

           

            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)

            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),

           

            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)

            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),

        ])

 

    def forward(self, x, context, time):

        # x: (Batch_Size, 4, Height / 8, Width / 8)

        # context: (Batch_Size, Seq_Len, Dim)

        # time: (1, 1280)

 

        skip_connections = []

        for layers in self.encoders:

            x = layers(x, context, time)

            skip_connections.append(x)

 

        x = self.bottleneck(x, context, time)

 

        for layers in self.decoders:

            # Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer

            x = torch.cat((x, skip_connections.pop()), dim=1)

            x = layers(x, context, time)

       

        return x
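
The decoder input channel counts come from this concatenation: e.g. the first decoder stage sees 1280 channels from the bottleneck plus 1280 from the popped skip connection = 2560. A hedged end-to-end shape check (assuming attention.py provides working SelfAttention/CrossAttention):

unet = UNET()
x = torch.randn(1, 4, 32, 32)          # latent of a 256×256 image (H/8 = 32)
context = torch.randn(1, 77, 768)      # CLIP text embedding
time = torch.randn(1, 1280)            # stand-in for the TimeEmbedding output
print(unet(x, context, time).shape)    # torch.Size([1, 320, 32, 32])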

 

Converts the 320-channel feature map from the UNET decoder back to 4 channels, matching the input latent (the predicted noise must have the same number of channels as the latent).

class UNET_OutputLayer(nn.Module):

    def __init__(self, in_channels, out_channels):  # in=320, out=4

        super().__init__()

        self.groupnorm = nn.GroupNorm(32, in_channels)

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

 

    def forward(self, x):  # x: (B, 320, H/8, W/8)

        x = self.groupnorm(x)  # normalize
 
        x = F.silu(x)          # activation
 
        x = self.conv(x)       # 320 → 4 channels (resolution unchanged)
 
        return x  # output: (B, 4, H/8, W/8) → the predicted noise

 

 

 

The Diffusion class wraps the whole pipeline (time embedding → UNET feature processing → noise output) and serves as the single entry point for both training and inference.

 

class Diffusion(nn.Module):

    def __init__(self):

        super().__init__()

        self.time_embedding = TimeEmbedding(320)

        self.unet = UNET()

        self.final = UNET_OutputLayer(320, 4)

   

    def forward(self, latent, context, time):

        # latent: (Batch_Size, 4, Height / 8, Width / 8)

        # context: (Batch_Size, Seq_Len, Dim)

        # time: (1, 320)

 

        # (1, 320) -> (1, 1280)

        time = self.time_embedding(time)

       

        # (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)

        output = self.unet(latent, context, time)

       

        # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)

        output = self.final(output)

       

        # (Batch, 4, Height / 8, Width / 8)

        return output
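
A hedged usage sketch of the wrapper (get_time_embedding is the hypothetical sinusoidal helper sketched earlier):

model = Diffusion()
latent = torch.randn(1, 4, 32, 32)          # noisy latent x_t for a 256×256 image
context = torch.randn(1, 77, 768)           # CLIP text embedding of the prompt
time = get_time_embedding(50)               # (1, 320) encoding of timestep t = 50
noise_pred = model(latent, context, time)   # (1, 4, 32, 32): same shape as the latent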

 
